import torch
from typing import Optional, Tuple, Any, Dict
import lightning as L
from pydantic import confloat
from .utils import registry
from . import net    
from transformers import get_linear_schedule_with_warmup
    
    
class NanoGPT(L.LightningModule):
    """A minimal GPT language model wrapped as a LightningModule.

    Wraps ``net.GPT`` and provides the training/validation steps, an AdamW
    optimizer with separate weight-decay parameter groups, and a per-step
    linear warmup learning-rate schedule.
    """

    def __init__(self,
                 n_embed: int,
                 n_head: int,
                 p_drop: float,
                 n_block: int,
                 block_size: int,
                 lr: float,
                 num_warmup_steps: int,
                 weight_decay: float,
                 vocab_size: int):
        """
        Args:
            n_embed (int): Token embedding dimension.
            n_head (int): Number of attention heads.
            p_drop (float): Dropout probability.
            n_block (int): Number of attention (transformer) blocks.
            block_size (int): Input sequence length of an attention block.
            lr (float): Learning rate for AdamW.
            num_warmup_steps (int): Warmup steps for the linear LR schedule.
            weight_decay (float): Weight decay applied to parameters outside
                the no-decay group (biases, LayerNorm weights).
            vocab_size (int): Vocabulary size.
        """
        super().__init__()

        # Record every constructor argument on self.hparams (and in checkpoints).
        self.save_hyperparameters()

        self.net = net.GPT(n_embed=n_embed,
                           n_head=n_head,
                           p_drop=p_drop,
                           n_block=n_block,
                           block_size=block_size,
                           vocab_size=vocab_size)

    def forward(self, input_ids: torch.Tensor, labels: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the underlying GPT.

        Returns:
            ``(logits, loss)``. The loss is presumably ``None`` when
            ``labels`` is ``None`` — confirm against ``net.GPT``.
        """
        # NOTE: the previous ``-> torch.Tensor`` annotation did not match the
        # tuple actually returned here.
        logits, loss = self.net(input_ids=input_ids, labels=labels)

        return logits, loss

    def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """One training step; logs ``train/loss`` per step and returns the loss."""
        # batch is a mapping with 'input_ids'/'labels' keys (the earlier Tuple
        # annotation did not match this indexing).
        input_ids, labels = batch['input_ids'], batch['labels']
        _, loss = self(input_ids, labels)
        self.log("train/loss", loss, on_step=True, prog_bar=True)
        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
        """One validation step; logs epoch-aggregated ``val/loss``."""
        input_ids, labels = batch['input_ids'], batch['labels']
        _, loss = self(input_ids, labels)
        self.log("val/loss", loss, on_epoch=True, prog_bar=True)
        return loss

    def get_linear_warmup_step_scheduler_config(self, optimizer, num_warmup_steps: int) -> Dict:
        """Build a Lightning scheduler config: linear warmup then linear decay,
        stepped once per optimizer step."""
        # The Trainer estimates the total number of optimization steps.
        total_steps = self.trainer.estimated_stepping_batches
        scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, num_training_steps=total_steps, num_warmup_steps=num_warmup_steps)
        scheduler_config = {'scheduler': scheduler, 'interval':'step'}
        return scheduler_config

    def configure_optimizers(self):
        """Create AdamW with two parameter groups: with and without weight decay."""
        # Biases and LayerNorm parameters are excluded from weight decay.
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        grouped_parameters = [
            {'params': [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)],
             'lr': self.hparams.lr, 'weight_decay': self.hparams.weight_decay},
            {'params': [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)],
             'lr': self.hparams.lr, 'weight_decay': 0.0}
        ]
        optimizer = torch.optim.AdamW(grouped_parameters)
        scheduler_config = self.get_linear_warmup_step_scheduler_config(optimizer, self.hparams.num_warmup_steps)
        return [optimizer], [scheduler_config]

    def generate(self, ids: torch.Tensor, n_token: int, temprature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:
        """Generate tokens autoregressively (inference).

        Args:
            ids (torch.Tensor): Input token ids.
            n_token (int): Number of tokens to generate.
            temprature (float, optional): Temperature coefficient. Defaults to 1.0.
                (Misspelling kept: it matches the ``net.GPT.generate`` keyword.)
            top_k (Optional[int], optional): Restrict sampling to the top-k
                vocabulary entries. Defaults to None.
        """

        ids = self.net.generate(ids=ids, n_token=n_token, temprature=temprature, top_k=top_k)

        return ids
    


@registry.models('nanogpt.base')
def build_gpt_base(vocab_size: int = 50257, p_drop: confloat(gt=0, lt=1) = 0.1, lr: float = 2e-4, num_warmup_steps: int = 1000, weight_decay: float = 0.01):
    """Factory for the 'base' NanoGPT: 12 blocks, 12 heads, 768-dim embeddings,
    1024-token context."""
    base_arch = {'n_embed': 768, 'n_head': 12, 'n_block': 12, 'block_size': 1024}
    return NanoGPT(vocab_size=vocab_size,
                   p_drop=p_drop,
                   lr=lr,
                   num_warmup_steps=num_warmup_steps,
                   weight_decay=weight_decay,
                   **base_arch)



@registry.models('nanogpt.medium')
def build_gpt_medium(vocab_size: int = 50257, p_drop: confloat(gt=0, lt=1) = 0.1, lr: float = 2e-4, num_warmup_steps: int = 1000, weight_decay: float = 0.01):
    """Factory for the 'medium' NanoGPT: 24 blocks, 16 heads, 1024-dim
    embeddings, 1024-token context."""
    medium_arch = {'n_embed': 1024, 'n_head': 16, 'n_block': 24, 'block_size': 1024}
    return NanoGPT(vocab_size=vocab_size,
                   p_drop=p_drop,
                   lr=lr,
                   num_warmup_steps=num_warmup_steps,
                   weight_decay=weight_decay,
                   **medium_arch)


class DeepSpeedNanoGPTModule(NanoGPT):
    """NanoGPT variant intended for DeepSpeed training strategies.

    Resolves the optimizer device type (FusedAdam implies CUDA) and builds
    the network inside Lightning's sharded-model hook.
    """

    # TODO: activation checkpointing (requires overriding forward)
    def __init__(self, fused_adam: bool = True, offload: bool = False, **kwargs: Any):
        """
        Args:
            fused_adam (bool): Use the CUDA FusedAdam optimizer path.
                Mutually exclusive with ``offload``.
            offload (bool): Use the CPU-offloaded (CPUAdam) optimizer path.
            **kwargs: Forwarded verbatim to :class:`NanoGPT` (n_embed, n_head,
                p_drop, n_block, block_size, lr, num_warmup_steps,
                weight_decay, vocab_size). An optional ``device_type`` entry
                is consumed here and not forwarded.

        Raises:
            RuntimeError: If both ``fused_adam`` and ``offload`` are requested.
        """
        if fused_adam and offload:
            raise RuntimeError(
                "Cannot use FusedAdam and CPUAdam at the same time! "
                "Please set either `fused_adam` or `offload` to False."
            )

        # fused_adam forces the resolved device to "cuda"; otherwise honor a
        # caller-provided kwargs["device_type"], defaulting to "cpu".
        # BUG FIX: previously the resolved value was written back into
        # `kwargs` and passed to super().__init__(), but NanoGPT.__init__()
        # accepts no `device_type` parameter, so construction would raise
        # TypeError. Resolve it locally and record it on hparams instead.
        device_type = "cuda" if fused_adam or kwargs.pop("device_type", "cpu") == "cuda" else "cpu"

        super().__init__(**kwargs)
        self.save_hyperparameters()
        self.hparams["device_type"] = device_type


    def configure_sharded_model(self) -> None:
        """Instantiate the GPT network inside the strategy's model-sharding hook.

        BUG FIX: the module was previously bound to ``self.nanogpt`` while the
        inherited forward/training path (NanoGPT.forward) reads ``self.net``,
        so the network built here was never used. Bind it to ``self.net``.
        """
        # NOTE(review): NanoGPT.__init__ already builds self.net eagerly;
        # rebuilding it here replaces that instance — confirm this matches the
        # intended DeepSpeed initialization flow.
        self.net = net.GPT(n_embed=self.hparams.n_embed,
                           n_head=self.hparams.n_head,
                           p_drop=self.hparams.p_drop,
                           n_block=self.hparams.n_block,
                           block_size=self.hparams.block_size,
                           vocab_size=self.hparams.vocab_size)